{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.9 64-bit",
   "metadata": {
    "interpreter": {
     "hash": "08bf4b89015df3fa62ec3e5cebfe3c3e5a181176f7e37ebd57af46c3a62ed99e"
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import d2lzh_pytorch as d2l\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "def vgg_block(num_convs, in_channels, out_channels):\n",
    "    blk = []\n",
    "    for i in range(num_convs):\n",
    "        if i == 0:\n",
    "            blk.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))\n",
    "        \n",
    "        else:\n",
    "            blk.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))\n",
    "        \n",
    "        blk.append(nn.ReLU())\n",
    "    \n",
    "    blk.append(nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "\n",
    "    return nn.Sequential(*blk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "vgg_block_1 output shape :  torch.Size([1, 64, 112, 112])\n",
      "vgg_block_2 output shape :  torch.Size([1, 128, 56, 56])\n",
      "vgg_block_3 output shape :  torch.Size([1, 256, 28, 28])\n",
      "vgg_block_4 output shape :  torch.Size([1, 512, 14, 14])\n",
      "vgg_block_5 output shape :  torch.Size([1, 512, 7, 7])\n",
      "fc output shape :  torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "# 模块串联数个vgg_block, 其超参数由变量conv_arch定义\n",
    "conv_arch = ((1, 1, 64), (1, 64, 128), (2, 128, 256), (2, 256, 512), (2, 512, 512))\n",
    "fc_features = 512 * 7 * 7 # c * w * h\n",
    "fc_hidden_units = 4096\n",
    "\n",
    "def vgg(conv_arch, fc_features, fc_hidden_units=4096, num_classes=10):\n",
    "    \"\"\"Assemble a VGG-style network from a sequence of block specifications.\n",
    "\n",
    "    Args:\n",
    "        conv_arch: iterable of (num_convs, in_channels, out_channels) triples,\n",
    "            one per vgg_block, applied in order.\n",
    "        fc_features: size of the flattened feature map fed to the first\n",
    "            fully connected layer.\n",
    "        fc_hidden_units: width of the two hidden fully connected layers.\n",
    "        num_classes: size of the final output layer. Defaults to 10, the\n",
    "            value that used to be hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        nn.Sequential whose children are named \"vgg_block_1\", \"vgg_block_2\", ...\n",
    "        followed by the classifier head \"fc\".\n",
    "    \"\"\"\n",
    "    net = nn.Sequential()\n",
    "\n",
    "    # Blocks are numbered from 1 so the child names read vgg_block_1, vgg_block_2, ...\n",
    "    for i, (num_convs, in_channels, out_channels) in enumerate(conv_arch, start=1):\n",
    "        net.add_module(\"vgg_block_\" + str(i), vgg_block(num_convs, in_channels, out_channels))\n",
    "\n",
    "    # Classifier head: flatten, two hidden layers with ReLU + dropout, then the output layer.\n",
    "    net.add_module(\"fc\", nn.Sequential(\n",
    "        d2l.FlattenLayer(),\n",
    "        nn.Linear(fc_features, fc_hidden_units),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(0.5),\n",
    "        nn.Linear(fc_hidden_units, fc_hidden_units),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(0.5),\n",
    "        nn.Linear(fc_hidden_units, num_classes)\n",
    "    ))\n",
    "\n",
    "    return net\n",
    "\n",
    "\n",
    "\n",
    "net = vgg(conv_arch, fc_features, fc_hidden_units)\n",
    "X = torch.rand(1, 1, 224, 224)\n",
    "\n",
    "# Shape check only: no gradients are needed, so skip autograd bookkeeping\n",
    "# for this throwaway forward pass through the full-width network.\n",
    "with torch.no_grad():\n",
    "    for name, blk in net.named_children():\n",
    "        X = blk(X)\n",
    "        print(name, 'output shape : ', X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Sequential(\n  (vgg_block_1): Sequential(\n    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n    (1): ReLU()\n    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n  )\n  (vgg_block_2): Sequential(\n    (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n    (1): ReLU()\n    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n  )\n  (vgg_block_3): Sequential(\n    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n    (1): ReLU()\n    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n    (3): ReLU()\n    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n  )\n  (vgg_block_4): Sequential(\n    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n    (1): ReLU()\n    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n    (3): ReLU()\n    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n  )\n  (vgg_block_5): Sequential(\n    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n    (1): ReLU()\n    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n    (3): ReLU()\n    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n  )\n  (fc): Sequential(\n    (0): FlattenLayer()\n    (1): Linear(in_features=3136, out_features=512, bias=True)\n    (2): ReLU()\n    (3): Dropout(p=0.5, inplace=False)\n    (4): Linear(in_features=512, out_features=512, bias=True)\n    (5): ReLU()\n    (6): Dropout(p=0.5, inplace=False)\n    (7): Linear(in_features=512, out_features=10, bias=True)\n  )\n)\n"
     ]
    }
   ],
   "source": [
    "# For testing purposes, divide the channel counts by `ratio` to get a narrower network.\n",
    "ratio = 8\n",
    "small_conv_arch = [\n",
    "    (1, 1, 64 // ratio),\n",
    "    (1, 64 // ratio, 128 // ratio),\n",
    "    (2, 128 // ratio, 256 // ratio),\n",
    "    (2, 256 // ratio, 512 // ratio),\n",
    "    (2, 512 // ratio, 512 // ratio),\n",
    "]\n",
    "# The fully connected sizes shrink by the same factor as the channels.\n",
    "net = vgg(small_conv_arch, fc_features // ratio, fc_hidden_units // ratio)\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.8349, train acc 0.683, test acc 0.877, time 117.0 sec\n",
      "epoch 2, loss 0.3257, train acc 0.883, test acc 0.894, time 114.2 sec\n",
      "epoch 3, loss 0.2773, train acc 0.900, test acc 0.906, time 114.3 sec\n",
      "epoch 4, loss 0.2494, train acc 0.909, test acc 0.908, time 114.3 sec\n",
      "epoch 5, loss 0.2238, train acc 0.918, test acc 0.914, time 115.2 sec\n"
     ]
    }
   ],
   "source": [
    "# Hyperparameters for the training run.\n",
    "batch_size = 64\n",
    "lr = 0.001\n",
    "num_epochs = 5\n",
    "\n",
    "# resize=224 matches the 224x224 input used for the shape check above.\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)\n",
    "\n",
    "# Trains whichever `net` was built last (the narrow one when cells run in order).\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}